import numpy as np
from ar_sim.common.kernel_builder import build_reproduction_kernel


import numpy as np
from ar_sim.common.kernel_builder import build_reproduction_kernel


def collapse_kernel_2d(field: np.ndarray,
                       n_vals: np.ndarray,
                       pivot_params: dict,
                       sigma: float = 1.0) -> np.ndarray:
    """
    Collapse a field by multiplying it with the 2D reproduction kernel.

    The kernel matrix is obtained from ``build_reproduction_kernel``; per the
    intended model it has entries
    M_ij = g(D_i) * exp[-(n_i-n_j)^2/(2 sigma^2)] and reduces to the identity
    at D=2 (that property lives in the kernel builder, not in this function).

    Args:
      field: 1D array of length N (field values at context indices n_vals)
      n_vals: 1D array of context indices of length N
      pivot_params: dict with keys 'a', 'b', 'D_vals' (forwarded unchanged)
      sigma: kernel width parameter

    Returns:
      collapsed_field: 1D array of length N, the kernel applied to ``field``
    """
    kernel = build_reproduction_kernel(n_vals, pivot_params, sigma)
    collapsed_field = kernel.dot(field)
    return collapsed_field


def collapse_kernel_4d(field: np.ndarray,
                       n_vals: np.ndarray,
                       pivot_params: dict,
                       l_max: int = 3,
                       sigma: float = 1.0) -> np.ndarray:
    """
    Placeholder for the 4D collapse over the null-cone hypersphere.

    For now this simply forwards to ``collapse_kernel_2d`` with the same
    field, indices, pivot parameters and sigma. ``l_max`` is accepted so the
    signature is ready for a hyperspherical-harmonic expansion, but it is
    currently ignored.

    Args:
      field: 1D array of length N
      n_vals: 1D array of context indices
      pivot_params: dict with pivot calibration
      l_max: maximum spherical harmonic degree (unused at present)
      sigma: kernel width parameter

    Returns:
      collapsed_field: 1D array of length N
    """
    # TODO: implement true 4D collapse via spherical harmonics; l_max is
    # intentionally unused until then.
    collapsed_field = collapse_kernel_2d(field, n_vals, pivot_params, sigma)
    return collapsed_field


def tick_retarded(field: np.ndarray,
                  n_vals: np.ndarray) -> np.ndarray:
    """
    Advance the field by one context index (retarded tick).

    Implemented as a placeholder circular shift: each value moves to the next
    index, and the last value wraps around to position 0.

    Args:
      field: 1D array of length N
      n_vals: 1D array of context indices (currently unused)

    Returns:
      ticked_field: 1D array of length N
    """
    ticked_field = np.roll(field, shift=1)
    return ticked_field


def composite_moment_4d(field: np.ndarray,
                        n_vals: np.ndarray,
                        pivot_params: dict,
                        sigma: float = 1.0) -> np.ndarray:
    """
    Composite 4D moment operator: the retarded tick applied after the 4D collapse.

    The collapse uses ``collapse_kernel_4d``'s default ``l_max``; only
    ``sigma`` is forwarded.

    Args:
      field: 1D array of length N
      n_vals: 1D array of context indices
      pivot_params: dict with pivot calibration
      sigma: kernel width parameter

    Returns:
      composite_field: 1D array of length N
    """
    return tick_retarded(
        collapse_kernel_4d(field, n_vals, pivot_params, sigma=sigma),
        n_vals,
    )
